package org.jetbrains.plugins.scala.annotator

import com.intellij.psi.{PsiDocumentManager, PsiElement, PsiFile}
import com.intellij.testFramework.fixtures.CodeInsightTestFixture
import org.intellij.lang.annotations.Language
import org.jetbrains.plugins.scala.annotator.hints.AnnotatorHints
import org.jetbrains.plugins.scala.extensions.{IterableOnceExt, PsiElementExt, StringExt}
import org.jetbrains.plugins.scala.util.assertions.MatcherAssertions
import org.junit.Assert.fail

//TODO: find a better name for this base class?
// The current name is confusing:
// this test class doesn't exactly test the highlighting that you see in the editor.
// It tests the annotations generated by ScalaAnnotator, which is applied to every element in the file recursively
// (see `org.jetbrains.plugins.scala.annotator.ScalaHighlightingTestLike.messagesFromScalaCode`).
// In general, the result can differ from what you see in the editor.
// For example, if you paste this code into the editor:
// ```scala
// UnresolvedObject.unresolvedMethod
// ```
// you will see only 1 error, indicating that "UnresolvedObject" is unresolved
// (see SCL-15138; this was done in one of its sub-tickets).
// However, with this base test class you will get 2 errors: one for "UnresolvedObject" and one for "unresolvedMethod".
// For general highlighting tests, use `myFixture.doHighlighting()` or a base class that uses it.
trait ScalaHighlightingTestLike extends MatcherAssertions {
  /** The IntelliJ test fixture used to create the test file and to access the project. */
  protected def getFixture: CodeInsightTestFixture

  /**
   * File name used by the `*FromScalaCode(text)` helpers when the caller does not supply one.
   * A constant value definition (no type ascription), so it is inlined and safe to use at any point of initialization.
   */
  private final val DefaultFileName = "dummy.scala"

  //////////////////////////////////////////////////
  // Assertions START
  //////////////////////////////////////////////////

  /** Asserts that annotating `code` produces no error messages. */
  protected def assertNoErrors(@Language("Scala") code: String): Unit =
    assertErrors(code, Nil: _*)

  /** Asserts that the error messages produced for `code` are exactly `messages`, in the same order. */
  protected def assertErrors(@Language("Scala") code: String, messages: Message*): Unit =
    assertErrorsText(code, messages.mkString("\n"))

  /** Asserts that the annotator hints followed by the error messages produced for `code` are exactly `messages`. */
  protected def assertErrorsWithHints(@Language("Scala") code: String, messages: Message*): Unit =
    assertErrorsWithHintsText(code, messages.mkString("\n"))

  /** Asserts that all messages (of any severity) produced for `code` are exactly `messages`, in the same order. */
  protected def assertMessages(@Language("Scala") code: String, messages: Message*): Unit =
    assertMessagesText(code, messages.mkString("\n"))

  /** Asserts that annotating `code` produces no messages at all. */
  protected def assertNoMessages(@Language("Scala") code: String): Unit =
    assertMessages(code, Nil: _*)

  /**
   * Like [[assertErrors]], but the expected errors are given as one newline-separated string
   * (see [[assertMessagesTextImpl]] for how that string is normalized).
   */
  protected def assertErrorsText(@Language("Scala") code: String, messagesConcatenated: String): Unit = {
    val actualMessages = errorsFromScalaCode(code)
    assertMessagesTextImpl(messagesConcatenated, actualMessages)
  }

  /** Like [[assertErrorsWithHints]], but the expected messages are given as one newline-separated string. */
  protected def assertErrorsWithHintsText(@Language("Scala") code: String, messagesConcatenated: String): Unit = {
    val actualMessages = errorsWithHintsFromScalaCode(code)
    assertMessagesTextImpl(messagesConcatenated, actualMessages)
  }

  /** Like [[assertMessages]], but the expected messages are given as one newline-separated string. */
  protected def assertMessagesText(@Language("Scala") code: String, messagesConcatenated: String): Unit = {
    val actualMessages = messagesFromScalaCode(code)
    assertMessagesTextImpl(messagesConcatenated, actualMessages)
  }

  /**
   * Compares the expected messages text with the actual messages joined by `"\n"`.
   *
   * Only the expected text is normalized: line separators go through `withNormalizedSeparator`,
   * runs of blank lines are collapsed and the result is trimmed. The actual text is used as-is.
   *
   * @param expectedMessagesConcatenated expected messages, one per line
   * @param actualMessages               messages collected from the annotator
   */
  protected def assertMessagesTextImpl(
    expectedMessagesConcatenated: String,
    actualMessages: Seq[Message],
  ): Unit = {
    // handle windows '\r', ignore empty lines
    val messagesConcatenatedClean =
      expectedMessagesConcatenated.withNormalizedSeparator.replaceAll("\\n\\n+", "\n").trim

    val actualMessagesConcatenated = actualMessages.mkString("\n")
    assertEqualsFailable(
      messagesConcatenatedClean,
      actualMessagesConcatenated
    )
  }

  //////////////////////////////////////////////////
  // Assertions END
  //////////////////////////////////////////////////

  //////////////////////////////////////////////////
  // Annotations extraction logic START
  //////////////////////////////////////////////////

  /** Same as `errorsFromScalaCode(text)`; the separate name only enables Scala 3 language injection for `scalaFileText`. */
  protected def errorsFromScala3Code(@Language("Scala 3") scalaFileText: String): List[Message.Error] =
    errorsFromScalaCode(scalaFileText)

  /** Creates the test file with the default name and returns the errors reported for it. */
  protected def errorsFromScalaCode(@Language("Scala") scalaFileText: String): List[Message.Error] =
    errorsFromScalaCode(scalaFileText, DefaultFileName)

  /** Creates the test file with the default name and returns its annotator hints followed by its errors. */
  protected def errorsWithHintsFromScalaCode(@Language("Scala") scalaFileText: String): List[Message] = {
    configureFile(scalaFileText, DefaultFileName)
    errorsWithHintsFromScalaCode(getFixture.getFile)
  }

  /** Creates the test file with the default name and returns every message the annotator produced for it. */
  protected def messagesFromScalaCode(@Language("Scala") scalaFileText: String): List[Message] = {
    configureFile(scalaFileText, DefaultFileName)
    messagesFromScalaCode(getFixture.getFile)
  }

  /** Creates the test file named `fileName` and returns the errors reported for it. */
  protected def errorsFromScalaCode(@Language("Scala") scalaFileText: String, fileName: String): List[Message.Error] = {
    configureFile(scalaFileText, fileName)
    errorsFromScalaCode(getFixture.getFile)
  }

  // Guards against configuring more than one file per test.
  // NOTE(review): this relies on the test framework creating a fresh instance per test method;
  // otherwise the flag would leak into subsequent tests — confirm for the runners in use.
  private var filesCreated: Boolean = false

  /** Configures the fixture with a single file; fails the test if a file was already configured. */
  private def configureFile(@Language("Scala") scalaFileText: String, fileName: String): Unit = {
    if (filesCreated)
      fail("Don't add files 2 times in a single test")

    getFixture.configureByText(fileName, scalaFileText)

    filesCreated = true
  }

  /** Returns only the error messages (with non-null element and text) the annotator produced for `file`. */
  protected def errorsFromScalaCode(file: PsiFile): List[Message.Error] =
    nonEmptyMessagesFromScalaCode(file).filterByType[Message.Error]

  /**
   * Returns the annotator hints attached to the elements of `file`, followed by the errors for `file`
   * (errors are filtered the same way as in `errorsFromScalaCode(file)`).
   */
  protected def errorsWithHintsFromScalaCode(file: PsiFile): List[Message] = {
    val errors = nonEmptyMessagesFromScalaCode(file).filterByType[Message.Error]

    val hints = file.elements
      .flatMap(AnnotatorHints.in(_).toSeq.flatMap(_.hints))
      .map(convertHintToTestHint)
      .toList

    hints ::: errors
  }

  /** Converts a production hint into the test-side [[Message.Hint]] (element text, joined part strings, offset delta). */
  private def convertHintToTestHint(hint: org.jetbrains.plugins.scala.annotator.hints.Hint): Message.Hint =
    Message.Hint(
      hint.element.getText,
      hint.parts.map(_.string).mkString,
      offsetDelta = hint.offsetDelta
    )

  /** All messages for `file`, dropping those with a null element or a null message text. */
  private def nonEmptyMessagesFromScalaCode(file: PsiFile): List[Message] =
    messagesFromScalaCode(file).filter(m => m.element != null && m.message != null)

  /**
   * Commits all documents, then runs [[annotate]] on every element of `file` (depth-first)
   * and returns the annotations collected by a mock holder.
   */
  protected def messagesFromScalaCode(file: PsiFile): List[Message] = {
    PsiDocumentManager.getInstance(getFixture.getProject).commitAllDocuments()

    val annotationHolder: AnnotatorHolderMock = new AnnotatorHolderMock(file)

    file.depthFirst().foreach(annotate(_)(annotationHolder))

    annotationHolder.annotations
  }

  /** Annotates a single element with [[ScalaAnnotator]]; overridable to plug in a different annotation step. */
  protected def annotate(element: PsiElement)
                        (implicit holder: ScalaAnnotationHolder): Unit =
    new ScalaAnnotator().annotate(element)

  //////////////////////////////////////////////////
  // Annotations extraction logic END
  //////////////////////////////////////////////////
}
